from vlmeval.smp import *
from vlmeval.api.base import BaseAPI

# HTTP content-type header line kept as a module-level constant.
# NOTE(review): nothing in the visible part of this file reads it; it may be
# kept only for external importers -- confirm before removing.
headers = "Content-Type: application/json"


class GeminiWrapper(BaseAPI):
    """API wrapper that sends text/image prompts to Google Gemini models.

    Two backends are supported:

    * ``"genai"``  -- the ``google.generativeai`` package, authenticated with
      an API key.
    * ``"vertex"`` -- the ``vertexai`` package, addressed through a project id
      (no API key is required by this wrapper).
    """

    is_api: bool = True

    def __init__(
        self,
        model: str = "gemini-1.0-pro",
        retry: int = 5,
        wait: int = 5,
        key: str = None,
        verbose: bool = True,
        temperature: float = 0.0,
        system_prompt: str = None,
        max_tokens: int = 1024,
        proxy: str = None,
        backend="genai",
        project_id="vlmeval",
        **kwargs,
    ):
        """Configure the wrapper.

        Args:
            model: Name of the Gemini model to query.
            retry: Forwarded to ``BaseAPI.__init__``.
            wait: Forwarded to ``BaseAPI.__init__``.
            key: API key; falls back to the ``GOOGLE_API_KEY`` environment
                variable when ``None``.
            verbose: Forwarded to ``BaseAPI.__init__``; also gates error
                logging in ``generate_inner``.
            temperature: Sampling temperature placed in the generation config.
            system_prompt: Optional prompt prepended to every message list.
            max_tokens: Value used for ``max_output_tokens`` in the generation
                config.
            proxy: When not ``None``, passed to ``proxy_set``.
            backend: ``"genai"`` or ``"vertex"``. A valid value in the
                ``GOOGLE_API_BACKEND`` environment variable takes precedence
                over this argument.
            project_id: Project passed to ``vertexai.init`` (vertex backend).
            **kwargs: Forwarded to ``BaseAPI.__init__``.

        Raises:
            AssertionError: If the backend is unknown, or the ``genai``
                backend is selected without an API key.
        """

        self.model = model
        self.fail_msg = "Failed to obtain answer via API. "
        self.max_tokens = max_tokens
        self.temperature = temperature
        if key is None:
            key = os.environ.get("GOOGLE_API_KEY", None)
        # Try to load backend from environment variable; a valid value there
        # overrides the `backend` argument, an invalid one is ignored.
        be = os.environ.get("GOOGLE_API_BACKEND", None)
        if be is not None and be in ["genai", "vertex"]:
            backend = be

        assert backend in ["genai", "vertex"], (
            f"Unsupported backend {backend!r}; expected 'genai' or 'vertex'."
        )
        if backend == "genai":
            # We have not evaluated Gemini-1.5 w. GenAI backend
            # Vertex does not require API Key
            assert key is not None, (
                "The 'genai' backend requires an API key: pass `key` or set "
                "the GOOGLE_API_KEY environment variable."
            )

        self.backend = backend
        self.project_id = project_id
        self.api_key = key

        if proxy is not None:
            proxy_set(proxy)
        super().__init__(
            wait=wait,
            retry=retry,
            system_prompt=system_prompt,
            verbose=verbose,
            **kwargs,
        )

    def build_msgs_genai(self, inputs):
        """Convert ``inputs`` into the content list used by the genai backend.

        Each item is a mapping with ``"type"`` and ``"value"`` keys. Text
        values are appended as-is, image values are opened with ``Image.open``;
        items of any other type are skipped. The system prompt, when set,
        becomes the first element.
        """
        messages = [] if self.system_prompt is None else [self.system_prompt]
        for inp in inputs:
            if inp["type"] == "text":
                messages.append(inp["value"])
            elif inp["type"] == "image":
                messages.append(Image.open(inp["value"]))
        return messages

    def build_msgs_vertex(self, inputs):
        """Convert ``inputs`` into the content list used by the vertex backend.

        Same item handling as ``build_msgs_genai``, except that images are
        loaded with ``vertexai``'s ``Image.load_from_file`` and wrapped in a
        ``Part``.
        """
        from vertexai.generative_models import Part, Image

        messages = [] if self.system_prompt is None else [self.system_prompt]
        for inp in inputs:
            if inp["type"] == "text":
                messages.append(inp["value"])
            elif inp["type"] == "image":
                messages.append(Part.from_image(Image.load_from_file(inp["value"])))
        return messages

    def _report_failure(self, err, inputs):
        """Log a failed request (when verbose) and return the failure triple."""
        if self.verbose:
            self.logger.error(f"{type(err)}: {err}")
            self.logger.error(f"The input messages are {inputs}.")

        return -1, "", ""

    def generate_inner(self, inputs, **kwargs) -> tuple:
        """Send one request to the configured backend.

        Args:
            inputs: List of ``{"type": ..., "value": ...}`` items.
            **kwargs: Extra generation-config fields; they override the
                instance-level ``max_output_tokens`` / ``temperature``.

        Returns:
            ``(0, answer, "Succeeded! ")`` on success, ``(-1, "", "")`` when
            the request (or building its generation config) raises.
        """
        # Built once so that both backends honour the same settings.
        gen_config = dict(
            max_output_tokens=self.max_tokens, temperature=self.temperature
        )
        gen_config.update(kwargs)

        if self.backend == "genai":
            import google.generativeai as genai

            assert isinstance(inputs, list)
            genai.configure(api_key=self.api_key)
            model = genai.GenerativeModel(self.model)

            messages = self.build_msgs_genai(inputs)
            try:
                # The config object is created inside the try block so that an
                # invalid kwarg is reported as a failed request.
                answer = model.generate_content(
                    messages,
                    generation_config=genai.types.GenerationConfig(**gen_config),
                ).text
                return 0, answer, "Succeeded! "
            except Exception as err:
                return self._report_failure(err, inputs)
        elif self.backend == "vertex":
            import vertexai
            from vertexai.generative_models import GenerationConfig, GenerativeModel

            vertexai.init(project=self.project_id, location="us-central1")
            # On vertex the default model name is mapped to its vision variant.
            model_name = (
                "gemini-1.0-pro-vision"
                if self.model == "gemini-1.0-pro"
                else self.model
            )
            model = GenerativeModel(model_name=model_name)
            messages = self.build_msgs_vertex(inputs)
            try:
                resp = model.generate_content(
                    messages,
                    generation_config=GenerationConfig(**gen_config),
                )
                answer = resp.text
                return 0, answer, "Succeeded! "
            except Exception as err:
                return self._report_failure(err, inputs)


class GeminiProVision(GeminiWrapper):
    """Gemini wrapper whose ``generate`` also accepts a ``dataset`` argument."""

    def generate(self, message, dataset=None):
        """Delegate to the parent class's ``generate`` with ``message`` only.

        ``dataset`` is accepted but is not passed on to the parent call.
        """
        parent = super()
        return parent.generate(message)
